import enum
import json
from abc import ABC, abstractmethod
from datetime import timedelta
from enum import Enum
from functools import cached_property, partial
from pathlib import Path
from typing import TYPE_CHECKING, Any, Literal

from pydantic import BaseModel, ConfigDict, Field, computed_field, field_validator, model_validator
from pydantic_core.core_schema import ValidationInfo

from dynamiq.utils import generate_uuid
from dynamiq.utils.env import get_env_var
from dynamiq.utils.logger import logger

if TYPE_CHECKING:
    from chromadb import ClientAPI as ChromaClient
    from openai import OpenAI as OpenAIClient
    from pinecone import Pinecone as PineconeClient
    from qdrant_client import QdrantClient
    from weaviate import WeaviateClient


class HTTPMethod(str, enum.Enum):
    """
    Closed set of HTTP request methods used by the HTTP-based connections.

    Because the enum also subclasses ``str``, every member compares equal to its
    plain string value (``HTTPMethod.GET == "GET"``).
    """

    GET = "GET"
    POST = "POST"
    PUT = "PUT"
    DELETE = "DELETE"
    PATCH = "PATCH"


class BaseConnection(BaseModel, ABC):
    """Represents a base connection class.

    This class should be subclassed to provide specific implementations for different types of
    connections.

    Attributes:
        id (str): A unique identifier for the connection, generated using `generate_uuid`.
        type (str): Read-only computed identifier of the concrete connection class, formed as
            ``"<module path without its last component>.<ClassName>"``. It is cached per instance
            and, being a pydantic computed field, included in serialized output.
    """
    id: str = Field(default_factory=generate_uuid)

    @computed_field
    @cached_property
    def type(self) -> str:
        """Return ``"<module without its last dotted component>.<class name>"`` for this instance."""
        return f"{self.__module__.rsplit('.', 1)[0]}.{self.__class__.__name__}"

    @property
    def conn_params(self) -> dict:
        """
        Returns the parameters required for connection.

        Returns:
            dict: An empty dictionary.
        """
        return {}

    def to_dict(self, for_tracing: bool = False, **kwargs) -> dict:
        """Converts the connection instance to a dictionary.

        Args:
            for_tracing: If True, return only the ``id`` and ``type`` of the connection
                (no other fields, so no credentials are included).
            **kwargs: Forwarded to ``model_dump``; ignored when ``for_tracing`` is True.

        Returns:
            dict: A dictionary representation of the connection instance.
        """
        if for_tracing:
            return {"id": self.id, "type": self.type}
        else:
            return self.model_dump(**kwargs)

    @abstractmethod
    def connect(self):
        """Connects to the service.

        This method should be implemented by subclasses to establish a connection to the service.

        Raises:
            NotImplementedError: If the method is not implemented by a subclass.
        """
        raise NotImplementedError


class BaseApiKeyConnection(BaseConnection):
    """
    Abstract connection authenticated by a single API key.

    Attributes:
        api_key (str): Secret key presented to the remote service.
    """
    api_key: str

    @abstractmethod
    def connect(self):
        """
        Establish a connection to the service with ``api_key``.

        Concrete subclasses must override this.

        Raises:
            NotImplementedError: When called on a subclass that did not override it.
        """
        raise NotImplementedError

    @property
    def conn_params(self) -> dict:
        """
        Parameters needed to connect.

        Returns:
            dict: ``{"api_key": <api_key>}``.
        """
        return dict(api_key=self.api_key)


class HttpApiKey(BaseApiKeyConnection):
    """
    Represents a connection to an API that uses an HTTP API key for authentication.

    Attributes:
        url (str): The URL of the API.
    """

    url: str

    def connect(self):
        """
        Returns the ``requests`` module for making HTTP requests to the API.

        No session object is created and no network I/O happens here; the module itself
        is returned.

        Returns:
            module: The ``requests`` module.
        """
        import requests

        return requests

    @property
    def conn_params(self) -> dict:
        """
        Returns the parameters required for connection.

        Returns:
            dict: A dictionary containing the API key with the key 'api_key' and base url with the key 'api_base'.
        """
        return {
            "api_base": self.url,
            "api_key": self.api_key,
        }


class Dynamiq(HttpApiKey):
    """
    Connection to the Dynamiq service.

    ``url`` and ``api_key`` may be passed explicitly; otherwise they are read from the
    ``DYNAMIQ_URL`` and ``DYNAMIQ_API_KEY`` environment variables.
    """

    url: str = Field(default_factory=partial(get_env_var, "DYNAMIQ_URL", "https://api.getdynamiq.ai"))
    api_key: str = Field(default_factory=partial(get_env_var, "DYNAMIQ_API_KEY"))
    headers: dict[str, Any] = Field(default_factory=dict)

    @model_validator(mode="after")
    def setup_headers(self):
        """Add an ``Authorization: Bearer <api_key>`` header whenever an API key is set."""
        if not self.api_key:
            return self
        self.headers["Authorization"] = f"Bearer {self.api_key}"
        return self

    @property
    def conn_params(self) -> dict:
        """Base ``api_base``/``api_key`` parameters plus a copy of ``headers`` when non-empty."""
        result = dict(super().conn_params)
        if self.headers:
            result["headers"] = dict(self.headers)
        return result


class Http(BaseConnection):
    """
    Represents a connection to an API.

    Attributes:
        url (str): The URL of the API, defaults to an empty string.
        method (HTTPMethod): HTTP method used for the request. Required on this class; subclasses
            such as Whisper or Tavily declare their own default.
        headers (dict[str, Any]): Additional headers to include in the request, defaults to an empty dictionary.
        params (Optional[dict[str, Any]]): Parameters to include in the request, defaults to an empty dictionary.
        data (Optional[dict[str, Any]]): Data to include in the request, defaults to an empty dictionary.
    """

    url: str = ""
    method: HTTPMethod
    headers: dict[str, Any] = Field(default_factory=dict)
    params: dict[str, Any] | None = Field(default_factory=dict)
    data: dict[str, Any] | None = Field(default_factory=dict)

    def connect(self):
        """
        Returns the ``requests`` module for making HTTP requests to the API.

        No session object is created and no network I/O happens here; the module itself
        is returned.

        Returns:
            module: The ``requests`` module.
        """
        import requests

        return requests


class OpenAI(BaseApiKeyConnection):
    """
    Connection to the OpenAI service.

    Attributes:
        api_key (str): OpenAI API key, read from the 'OPENAI_API_KEY' environment variable by default.
        url (str): API endpoint, read from 'OPENAI_URL' ("https://api.openai.com/v1" is the get_env_var default).
    """
    api_key: str = Field(default_factory=partial(get_env_var, "OPENAI_API_KEY"))
    url: str = Field(default_factory=partial(get_env_var, "OPENAI_URL", "https://api.openai.com/v1"))

    def connect(self) -> "OpenAIClient":
        """
        Build an OpenAI SDK client.

        Returns:
            OpenAIClient: Client configured with ``api_key`` and ``url`` as its base URL.
        """
        # Deferred import keeps the SDK out of memory until a client is actually needed.
        from openai import OpenAI as OpenAIClient

        client = OpenAIClient(api_key=self.api_key, base_url=self.url)
        logger.debug("Connected to OpenAI")
        return client

    @property
    def conn_params(self) -> dict:
        """Parameters for connecting: ``api_base`` (the url) and ``api_key``."""
        return dict(api_base=self.url, api_key=self.api_key)


class Anthropic(BaseApiKeyConnection):
    """Credentials for Anthropic; ``api_key`` is read from ``ANTHROPIC_API_KEY`` by default."""

    api_key: str = Field(default_factory=partial(get_env_var, "ANTHROPIC_API_KEY"))

    def connect(self):
        """No-op: no client is built for this provider; returns None."""
        return None


class AWS(BaseConnection):
    """
    AWS credentials and region settings, each read from the standard AWS environment variables by default.

    Attributes:
        access_key_id (str | None): From ``AWS_ACCESS_KEY_ID``.
        secret_access_key (str | None): From ``AWS_SECRET_ACCESS_KEY``.
        region (str): From ``AWS_DEFAULT_REGION``.
        profile (str | None): From ``AWS_DEFAULT_PROFILE``; when set it takes precedence over the keys.
    """

    access_key_id: str | None = Field(default_factory=partial(get_env_var, "AWS_ACCESS_KEY_ID"))
    secret_access_key: str | None = Field(default_factory=partial(get_env_var, "AWS_SECRET_ACCESS_KEY"))
    region: str = Field(default_factory=partial(get_env_var, "AWS_DEFAULT_REGION"))
    profile: str | None = Field(default_factory=partial(get_env_var, "AWS_DEFAULT_PROFILE"))

    def connect(self):
        """No-op: returns None; use ``get_boto3_session`` or ``conn_params`` instead."""
        return None

    @property
    def conn_params(self) -> dict:
        """Return parameters with aws_ prefix for compatibility with other systems.

        With a profile, only the profile name and region are returned; otherwise the access key
        pair and region are returned (values may be None).
        """
        if self.profile:
            credentials = {"aws_profile_name": self.profile}
        else:
            credentials = {
                "aws_access_key_id": self.access_key_id,
                "aws_secret_access_key": self.secret_access_key,
            }
        credentials["aws_region_name"] = self.region
        return credentials

    def get_boto3_session(self):
        """Create a ``boto3.Session`` from the profile, or else the key pair, plus the region if set."""
        import boto3

        session_kwargs = {}
        if self.profile:
            session_kwargs["profile_name"] = self.profile
        elif self.access_key_id and self.secret_access_key:
            session_kwargs.update(
                aws_access_key_id=self.access_key_id,
                aws_secret_access_key=self.secret_access_key,
            )
        if self.region:
            session_kwargs["region_name"] = self.region
        return boto3.Session(**session_kwargs)


class Gemini(BaseApiKeyConnection):
    """Credentials for Google Gemini; ``api_key`` is read from ``GEMINI_API_KEY`` by default."""

    api_key: str = Field(default_factory=partial(get_env_var, "GEMINI_API_KEY"))

    def connect(self):
        """No-op: no client is built for this provider; returns None."""
        return None


class GoogleCloud(BaseConnection):
    """
    Represents a connection to Google Cloud Platform (GCP) using service account credentials.

    Attributes:
        project_id (str): The GCP project ID.
        private_key_id (str): The private key ID used for authentication.
        private_key (str): The private key used for secure access.
        client_email (str): The service account email address.
        client_id (str): The unique client ID for authentication.
        auth_uri (str): The URI for Google's authentication endpoint.
        token_uri (str): The URI for obtaining OAuth tokens.
        auth_provider_x509_cert_url (str): The URL for Google's authentication provider X.509 certificates.
        client_x509_cert_url (str): The URL for the client's X.509 certificate.
        universe_domain (str): The domain associated with the Google Cloud environment.
    """

    project_id: str = Field(default_factory=partial(get_env_var, "GOOGLE_CLOUD_PROJECT_ID"))
    private_key_id: str = Field(default_factory=partial(get_env_var, "GOOGLE_CLOUD_PRIVATE_KEY_ID"))
    private_key: str = Field(default_factory=partial(get_env_var, "GOOGLE_CLOUD_PRIVATE_KEY"))
    client_email: str = Field(default_factory=partial(get_env_var, "GOOGLE_CLOUD_CLIENT_EMAIL"))
    client_id: str = Field(default_factory=partial(get_env_var, "GOOGLE_CLOUD_CLIENT_ID"))
    auth_uri: str = Field(default_factory=partial(get_env_var, "GOOGLE_CLOUD_AUTH_URI"))
    token_uri: str = Field(default_factory=partial(get_env_var, "GOOGLE_CLOUD_TOKEN_URI"))
    auth_provider_x509_cert_url: str = Field(
        default_factory=partial(get_env_var, "GOOGLE_CLOUD_AUTH_PROVIDER_X509_CERT_URL")
    )
    client_x509_cert_url: str = Field(default_factory=partial(get_env_var, "GOOGLE_CLOUD_CLIENT_X509_CERT_URL"))
    universe_domain: str = Field(default_factory=partial(get_env_var, "GOOGLE_CLOUD_UNIVERSE_DOMAIN"))

    def connect(self):
        pass

    @property
    def conn_params(self) -> dict:
        """
        Returns the parameters required for the connection.

        Returns:
            dict: The service account credential fields, keyed 'project_id', 'private_key_id',
            'private_key', 'client_email', 'client_id', 'client_x509_cert_url', 'auth_uri',
            'token_uri', 'auth_provider_x509_cert_url' and 'universe_domain'.
        """
        return {
            "project_id": self.project_id,
            "private_key_id": self.private_key_id,
            "private_key": self.private_key,
            "client_email": self.client_email,
            "client_id": self.client_id,
            "client_x509_cert_url": self.client_x509_cert_url,
            "auth_uri": self.auth_uri,
            "token_uri": self.token_uri,
            "auth_provider_x509_cert_url": self.auth_provider_x509_cert_url,
            "universe_domain": self.universe_domain,
        }


class VertexAI(GoogleCloud):
    """
    Represents a connection to the Vertex AI service.

    This connection requires additional GCP application credentials. The credentials should be provided in the
    connection fields (related to Google Cloud) or set in the environment variables.

    Attributes:
        vertex_project_id (str): The GCP project ID.
        vertex_project_location (str): The location of the GCP project.
    """

    vertex_project_id: str = Field(default_factory=partial(get_env_var, "VERTEXAI_PROJECT_ID"))
    vertex_project_location: str = Field(default_factory=partial(get_env_var, "VERTEXAI_PROJECT_LOCATION"))

    def connect(self):
        pass

    @property
    def conn_params(self) -> dict:
        """
        Returns the parameters required for the connection.

        Returns:
            dict: A dictionary with the keys
                - 'vertex_project': the Vertex AI project ID,
                - 'vertex_location': the Vertex AI project location,
                - 'vertex_credentials': a JSON string of the Google Cloud service account fields
                  returned by ``GoogleCloud.conn_params``.
        """
        # json.dumps does not mutate its argument, so no defensive copy is needed.
        vertex_credentials = json.dumps(super().conn_params)
        return {
            "vertex_project": self.vertex_project_id,
            "vertex_location": self.vertex_project_location,
            "vertex_credentials": vertex_credentials,
        }


class Cohere(BaseApiKeyConnection):
    """Credentials for Cohere; ``api_key`` is read from ``COHERE_API_KEY`` by default."""

    api_key: str = Field(default_factory=partial(get_env_var, "COHERE_API_KEY"))

    def connect(self):
        """No-op: no client is built for this provider; returns None."""
        return None


class Mistral(BaseApiKeyConnection):
    """Credentials for Mistral; ``api_key`` is read from ``MISTRAL_API_KEY`` by default."""

    api_key: str = Field(default_factory=partial(get_env_var, "MISTRAL_API_KEY"))

    def connect(self):
        """No-op: no client is built for this provider; returns None."""
        return None


class Whisper(Http):
    """
    HTTP connection to the Whisper API.

    Attributes:
        url (str): Base URL, read from "WHISPER_URL" ("https://api.openai.com/v1/" is the get_env_var default).
        method (str): HTTP method, HTTPMethod.POST by default.
        api_key (str): Bearer token, read from "OPENAI_API_KEY".
    """
    url: str = Field(default_factory=partial(get_env_var, "WHISPER_URL", "https://api.openai.com/v1/"))
    method: str = HTTPMethod.POST
    api_key: str = Field(default_factory=partial(get_env_var, "OPENAI_API_KEY"))

    @model_validator(mode="after")
    def setup_headers(self):
        """Add ``Authorization: Bearer <api_key>`` to headers when an API key is present."""
        if not self.api_key:
            return self
        self.headers["Authorization"] = f"Bearer {self.api_key}"
        return self


class ElevenLabs(Http):
    """
    HTTP connection to the ElevenLabs API.

    Attributes:
        url (str): Base URL, read from "ELEVENLABS_URL" ("https://api.elevenlabs.io/v1/" is the get_env_var default).
        method (str): HTTP method, HTTPMethod.POST by default.
        api_key (str): Sent in the "xi-api-key" header; read from "ELEVENLABS_API_KEY".
    """

    url: str = Field(default_factory=partial(get_env_var, "ELEVENLABS_URL", "https://api.elevenlabs.io/v1/"))
    method: str = HTTPMethod.POST
    api_key: str = Field(default_factory=partial(get_env_var, "ELEVENLABS_API_KEY"))

    @model_validator(mode="after")
    def setup_headers(self):
        """Put the API key into the ``xi-api-key`` header when one is present."""
        if not self.api_key:
            return self
        self.headers["xi-api-key"] = self.api_key
        return self


class Pinecone(BaseApiKeyConnection):
    """
    Connection to the Pinecone service.

    Attributes:
        api_key (str): Pinecone API key, read from 'PINECONE_API_KEY' by default.
    """

    api_key: str = Field(default_factory=partial(get_env_var, "PINECONE_API_KEY"))

    def connect(self) -> "PineconeClient":
        """
        Build a Pinecone SDK client authenticated with ``api_key``.

        Returns:
            PineconeClient: The constructed client.
        """
        # Deferred import keeps the SDK out of memory until a client is actually needed.
        from pinecone import Pinecone as PineconeClient

        client = PineconeClient(self.api_key)
        logger.debug("Connected to Pinecone")
        return client


class Qdrant(BaseApiKeyConnection):
    """
    Represents a connection to the Qdrant service.

    Attributes:
        url (str): The URL of the Qdrant service.
            Defaults to the environment variable 'QDRANT_URL'.
        api_key (str): The API key for the Qdrant service.
            Defaults to the environment variable 'QDRANT_API_KEY'.
    """

    url: str = Field(default_factory=partial(get_env_var, "QDRANT_URL"))
    api_key: str = Field(default_factory=partial(get_env_var, "QDRANT_API_KEY"))

    def connect(self) -> "QdrantClient":
        """
        Connects to the Qdrant service.

        Returns:
            QdrantClient: A client configured with ``url`` and ``api_key``.
        """
        # Import in runtime to save memory
        from qdrant_client import QdrantClient

        qdrant_client = QdrantClient(
            url=self.url,
            api_key=self.api_key,
        )
        logger.debug(f"Connected to Qdrant with url={self.url}")
        return qdrant_client


class WeaviateDeploymentType(str, enum.Enum):
    """
    Kinds of Weaviate deployment a connection can target.

    Members:
        WEAVIATE_CLOUD: Hosted on Weaviate Cloud (value ``"weaviate_cloud"``).
        CUSTOM: Self-managed deployment addressed by explicit hosts/ports (value ``"custom"``).
    """

    WEAVIATE_CLOUD = "weaviate_cloud"
    CUSTOM = "custom"


class Weaviate(BaseApiKeyConnection):
    """
    Connection to a Weaviate instance, either on Weaviate Cloud or a custom deployment.

    Attributes:
        deployment_type (WeaviateDeploymentType): Which deployment to connect to; WEAVIATE_CLOUD by default.
        api_key (str): API key, from 'WEAVIATE_API_KEY'.
        url (str): Cluster URL (cloud deployments), from 'WEAVIATE_URL'.
        http_host (str): HTTP host (custom deployments), from 'WEAVIATE_HTTP_HOST'.
        http_port (int): HTTP port, from 'WEAVIATE_HTTP_PORT' (443 is the get_env_var default).
        grpc_host (str): gRPC host (custom deployments), from 'WEAVIATE_GRPC_HOST'.
        grpc_port (int): gRPC port, from 'WEAVIATE_GRPC_PORT' (50051 is the get_env_var default).
    """

    deployment_type: WeaviateDeploymentType = WeaviateDeploymentType.WEAVIATE_CLOUD
    api_key: str = Field(default_factory=partial(get_env_var, "WEAVIATE_API_KEY"))
    url: str = Field(default_factory=partial(get_env_var, "WEAVIATE_URL"))
    http_host: str = Field(default_factory=partial(get_env_var, "WEAVIATE_HTTP_HOST"))
    http_port: int = Field(default_factory=partial(get_env_var, "WEAVIATE_HTTP_PORT", 443))
    grpc_host: str = Field(default_factory=partial(get_env_var, "WEAVIATE_GRPC_HOST"))
    grpc_port: int = Field(default_factory=partial(get_env_var, "WEAVIATE_GRPC_PORT", 50051))

    def connect(self) -> "WeaviateClient":
        """
        Open a Weaviate client for the configured deployment type.

        Returns:
            WeaviateClient: A client for the cloud cluster at ``url`` or the custom HTTP/gRPC endpoints.

        Raises:
            ValueError: If ``deployment_type`` is neither WEAVIATE_CLOUD nor CUSTOM.
        """
        # Deferred import keeps the SDK out of memory until a client is actually needed.
        from weaviate import connect_to_custom, connect_to_weaviate_cloud
        from weaviate.classes.init import AdditionalConfig, Auth, Timeout

        if self.deployment_type == WeaviateDeploymentType.WEAVIATE_CLOUD:
            client = connect_to_weaviate_cloud(
                cluster_url=self.url,
                auth_credentials=Auth.api_key(self.api_key),
            )
            logger.debug(f"Connected to Weaviate with url={self.url}")
            return client

        if self.deployment_type != WeaviateDeploymentType.CUSTOM:
            raise ValueError("Invalid deployment type")

        # Custom deployments: TLS on both channels; timeouts are in seconds.
        client = connect_to_custom(
            http_host=self.http_host,
            http_port=self.http_port,
            http_secure=True,
            grpc_host=self.grpc_host,
            grpc_port=self.grpc_port,
            grpc_secure=True,
            auth_credentials=Auth.api_key(self.api_key),
            additional_config=AdditionalConfig(timeout=Timeout(init=30, query=60, insert=120)),
            skip_init_checks=False,
        )
        logger.debug(f"Connected to Weaviate with http_host={self.http_host}")
        return client


class Chroma(BaseConnection):
    """
    Connection to a Chroma server over HTTP.

    Attributes:
        host (str): Server host, read from 'CHROMA_HOST'.
        port (int): Server port, read from 'CHROMA_PORT'.
    """

    host: str = Field(default_factory=partial(get_env_var, "CHROMA_HOST"))
    port: int = Field(default_factory=partial(get_env_var, "CHROMA_PORT"))

    @property
    def vector_store_cls(self):
        """
        The vector store class paired with this connection.

        Returns:
            type: ``ChromaVectorStore``, imported lazily from ``dynamiq.storages.vector``.
        """
        from dynamiq.storages.vector import ChromaVectorStore as store_cls

        return store_cls

    def connect(self) -> "ChromaClient":
        """
        Build a Chroma HTTP client for ``host``/``port``.

        Returns:
            ChromaClient: The constructed client.
        """
        # Deferred import keeps the SDK out of memory until a client is actually needed.
        from chromadb import HttpClient

        client = HttpClient(host=self.host, port=self.port)
        logger.debug(f"Connected to Chroma with host={self.host} and port={str(self.port)}")
        return client


class Unstructured(HttpApiKey):
    """
    Connection settings for the Unstructured API.

    Attributes:
        url (str): API base URL, read from 'UNSTRUCTURED_API_URL'
            ("https://api.unstructured.io/" is the get_env_var default).
        api_key (str): API key, read from 'UNSTRUCTURED_API_KEY'.
    """

    url: str = Field(default_factory=partial(get_env_var, "UNSTRUCTURED_API_URL", "https://api.unstructured.io/"))
    api_key: str = Field(default_factory=partial(get_env_var, "UNSTRUCTURED_API_KEY"))

    def connect(self):
        """
        No-op override: unlike ``HttpApiKey.connect``, this returns None instead of the requests module.
        """
        return None


class Tavily(Http):
    """
    Connection class for the Tavily API.

    The API key is added to the request ``data`` under the key ``api_key``.

    Attributes:
        url (str): Base URL, "https://api.tavily.com" by default.
        api_key (str): API key, read from the environment variable "TAVILY_API_KEY".
        method (Literal[HTTPMethod.POST]): Always POST.
    """

    url: str = Field(default="https://api.tavily.com")
    api_key: str = Field(default_factory=partial(get_env_var, "TAVILY_API_KEY"))
    method: Literal[HTTPMethod.POST] = HTTPMethod.POST

    @model_validator(mode="after")
    def setup_data(self):
        """Setup data after model validation."""
        if self.api_key:
            # ``data`` is typed ``dict | None`` on Http; an explicit None would make
            # ``.update`` raise AttributeError, so start from an empty dict instead.
            if self.data is None:
                self.data = {}
            self.data.update({"api_key": self.api_key})
        return self


class ScaleSerp(Http):
    """
    Connection class for Scale SERP Search API.

    The API key is added to the request ``params`` under the key ``api_key``.

    Attributes:
        url (str): Base URL, "https://api.scaleserp.com".
        api_key (str): API key, read from the environment variable "SERP_API_KEY".
        method (str): HTTP method, HTTPMethod.GET by default.
    """

    url: str = "https://api.scaleserp.com"
    api_key: str = Field(default_factory=partial(get_env_var, "SERP_API_KEY"))
    method: str = HTTPMethod.GET

    model_config = ConfigDict(arbitrary_types_allowed=True)

    @model_validator(mode="after")
    def setup_params(self):
        """Setup params after model validation."""
        if self.api_key:
            # ``params`` is typed ``dict | None`` on Http; an explicit None would make
            # ``.update`` raise AttributeError, so start from an empty dict instead.
            if self.params is None:
                self.params = {}
            self.params.update({"api_key": self.api_key})
        return self


class ZenRows(Http):
    """
    Connection class for ZenRows Scrape API.

    The API key is added to the request ``params`` under the key ``apikey``.

    Attributes:
        url (str): Base URL, "https://api.zenrows.com/v1/".
        api_key (str): API key, read from the environment variable "ZENROWS_API_KEY".
        method (str): HTTP method, HTTPMethod.GET by default.
    """

    url: str = "https://api.zenrows.com/v1/"
    api_key: str = Field(default_factory=partial(get_env_var, "ZENROWS_API_KEY"))
    method: str = HTTPMethod.GET

    @model_validator(mode="after")
    def setup_params(self):
        """Setup params after model validation."""
        if self.api_key:
            # ``params`` is typed ``dict | None`` on Http; an explicit None would make
            # ``.update`` raise AttributeError, so start from an empty dict instead.
            if self.params is None:
                self.params = {}
            self.params.update({"apikey": self.api_key})
        return self


class Groq(BaseApiKeyConnection):
    """Credentials for Groq; ``api_key`` is read from ``GROQ_API_KEY`` by default."""

    api_key: str = Field(default_factory=partial(get_env_var, "GROQ_API_KEY"))

    def connect(self):
        """No-op: no client is built for this provider; returns None."""
        return None


class TogetherAI(BaseApiKeyConnection):
    """Credentials for Together AI; ``api_key`` is read from ``TOGETHER_API_KEY`` by default."""

    api_key: str = Field(default_factory=partial(get_env_var, "TOGETHER_API_KEY"))

    def connect(self):
        """No-op: no client is built for this provider; returns None."""
        return None


class Anyscale(BaseApiKeyConnection):
    """Credentials for Anyscale; ``api_key`` is read from ``ANYSCALE_API_KEY`` by default."""

    api_key: str = Field(default_factory=partial(get_env_var, "ANYSCALE_API_KEY"))

    def connect(self):
        """No-op: no client is built for this provider; returns None."""
        return None


class Firecrawl(Http):
    """
    Connection class for the Firecrawl API.

    Attributes:
        url (str): Base URL, "https://api.firecrawl.dev/v2/" by default.
        api_key (str): Bearer token, read from the environment variable "FIRECRAWL_API_KEY".
        method (Literal[HTTPMethod.POST]): Always POST.
    """

    url: str = Field(default="https://api.firecrawl.dev/v2/")
    # ``partial`` (as used by the other connections) rather than a lambda: same lazy lookup,
    # but the factory is picklable.
    api_key: str = Field(default_factory=partial(get_env_var, "FIRECRAWL_API_KEY"))
    method: Literal[HTTPMethod.POST] = HTTPMethod.POST

    @model_validator(mode="after")
    def setup_headers(self):
        """Setup authorization headers after model validation."""
        if self.api_key:
            self.headers.update({"Authorization": f"Bearer {self.api_key}"})
        return self


class E2B(BaseApiKeyConnection):
    """
    Credentials for E2B.

    Attributes:
        api_key (str): Read from ``E2B_API_KEY`` by default.
        domain (str | None): Read from ``E2B_DOMAIN`` by default.
    """

    api_key: str = Field(default_factory=partial(get_env_var, "E2B_API_KEY"))
    domain: str | None = Field(default_factory=partial(get_env_var, "E2B_DOMAIN"))

    def connect(self):
        """No-op: no client is built for this provider; returns None."""
        return None


class HuggingFace(BaseApiKeyConnection):
    """Credentials for Hugging Face; ``api_key`` is read from ``HUGGINGFACE_API_KEY`` by default."""

    api_key: str = Field(default_factory=partial(get_env_var, "HUGGINGFACE_API_KEY"))

    def connect(self):
        """No-op: no client is built for this provider; returns None."""
        return None


class WatsonX(BaseApiKeyConnection):
    """
    Credentials for IBM watsonx.

    Attributes:
        api_key (str): Read from ``WATSONX_API_KEY`` by default.
        project_id (str): Read from ``WATSONX_PROJECT_ID`` by default.
        url (str): Read from ``WATSONX_URL`` by default.
    """

    api_key: str = Field(default_factory=partial(get_env_var, "WATSONX_API_KEY"))
    project_id: str = Field(default_factory=partial(get_env_var, "WATSONX_PROJECT_ID"))
    url: str = Field(default_factory=partial(get_env_var, "WATSONX_URL"))

    def connect(self):
        pass

    @property
    def conn_params(self) -> dict:
        """
        Returns the parameters required for connection.

        Returns:
            dict: A dictionary containing

                -the API key with the key 'api_key'.

                -the project ID with the key 'project_id'.

                -the url with the key 'api_base'.
        """
        return {
            "api_key": self.api_key,
            "project_id": self.project_id,
            "api_base": self.url,
        }


class AzureAI(BaseApiKeyConnection):
    """
    Credentials for Azure AI.

    Attributes:
        api_key (str): Read from ``AZURE_API_KEY`` by default.
        url (str): Base URL, read from ``AZURE_URL`` by default.
        api_version (str): Read from ``AZURE_API_VERSION`` by default.
    """

    api_key: str = Field(default_factory=partial(get_env_var, "AZURE_API_KEY"))
    url: str = Field(default_factory=partial(get_env_var, "AZURE_URL"))
    api_version: str = Field(default_factory=partial(get_env_var, "AZURE_API_VERSION"))

    def connect(self):
        """No-op: no client is built for this provider; returns None."""
        return None

    @property
    def conn_params(self) -> dict:
        """
        Parameters needed to connect.

        Returns:
            dict: ``api_base`` (the url), ``api_key`` and ``api_version``.
        """
        return dict(api_base=self.url, api_key=self.api_key, api_version=self.api_version)


class DeepInfra(BaseApiKeyConnection):
    """Credentials for DeepInfra; ``api_key`` is read from ``DEEPINFRA_API_KEY`` by default."""

    api_key: str = Field(default_factory=partial(get_env_var, "DEEPINFRA_API_KEY"))

    def connect(self):
        """No-op: no client is built for this provider; returns None."""
        return None


class Cerebras(BaseApiKeyConnection):
    """Credentials for Cerebras; ``api_key`` is read from ``CEREBRAS_API_KEY`` by default."""

    api_key: str = Field(default_factory=partial(get_env_var, "CEREBRAS_API_KEY"))

    def connect(self):
        """No-op: no client is built for this provider; returns None."""
        return None


class Replicate(BaseApiKeyConnection):
    """Credentials for Replicate; ``api_key`` is read from ``REPLICATE_API_KEY`` by default."""

    api_key: str = Field(default_factory=partial(get_env_var, "REPLICATE_API_KEY"))

    def connect(self):
        """No-op: no client is built for this provider; returns None."""
        return None


class AI21(BaseApiKeyConnection):
    """Credentials for AI21; ``api_key`` is read from ``AI21_API_KEY`` by default."""

    api_key: str = Field(default_factory=partial(get_env_var, "AI21_API_KEY"))

    def connect(self):
        """No-op: no client is built for this provider; returns None."""
        return None


class SambaNova(BaseApiKeyConnection):
    """Credentials for SambaNova; ``api_key`` is read from ``SAMBANOVA_API_KEY`` by default."""

    api_key: str = Field(default_factory=partial(get_env_var, "SAMBANOVA_API_KEY"))

    def connect(self):
        """No-op: no client is built for this provider; returns None."""
        return None


class MilvusDeploymentType(str, enum.Enum):
    """
    Kinds of Milvus deployment a connection can target.

    Members:
        FILE: Local file-backed deployment; the ``Milvus`` connection's URI must end in ``.db``
            (value ``"file"``).
        HOST: Deployment reached through a URI, with or without a token (value ``"host"``).
    """

    FILE = "file"
    HOST = "host"


class Milvus(BaseConnection):
    """
    Connection to a Milvus instance.

    Attributes:
        deployment_type (MilvusDeploymentType): FILE or HOST; HOST by default.
        uri (str): File path or host URL, read from 'MILVUS_URI'
            ("http://localhost:19530" is the get_env_var default).
        api_key (str | None): Token sent for HOST deployments when set, read from 'MILVUS_API_TOKEN'.
    """

    deployment_type: MilvusDeploymentType = MilvusDeploymentType.HOST
    uri: str = Field(default_factory=partial(get_env_var, "MILVUS_URI", "http://localhost:19530"))
    api_key: str | None = Field(default_factory=partial(get_env_var, "MILVUS_API_TOKEN", None))

    @field_validator("uri")
    @classmethod
    def validate_uri(cls, uri: str, values: ValidationInfo) -> str:
        """Reject a FILE deployment whose URI does not end with '.db'."""
        # deployment_type is declared before uri, so it is already present in values.data here.
        is_file_deployment = values.data.get("deployment_type") == MilvusDeploymentType.FILE
        if is_file_deployment and not uri.endswith(".db"):
            raise ValueError("For FILE deployment, URI should point to a file ending with '.db'.")
        return uri

    def connect(self):
        """
        Build a ``pymilvus.MilvusClient`` for the configured deployment.

        Raises:
            ValueError: If ``deployment_type`` is neither FILE nor HOST.
        """
        from pymilvus import MilvusClient

        if self.deployment_type == MilvusDeploymentType.FILE:
            return MilvusClient(uri=self.uri)

        if self.deployment_type == MilvusDeploymentType.HOST:
            if not self.api_key:
                return MilvusClient(uri=self.uri)
            return MilvusClient(uri=self.uri, token=self.api_key)

        raise ValueError("Invalid deployment type for Milvus connection.")


class Perplexity(BaseApiKeyConnection):
    """Credentials for Perplexity; ``api_key`` is read from ``PERPLEXITYAI_API_KEY`` by default."""

    api_key: str = Field(default_factory=partial(get_env_var, "PERPLEXITYAI_API_KEY"))

    def connect(self):
        """No-op: no client is built for this provider; returns None."""
        return None


class DeepSeek(BaseApiKeyConnection):
    """Credentials for DeepSeek; ``api_key`` is read from ``DEEPSEEK_API_KEY`` by default."""

    api_key: str = Field(default_factory=partial(get_env_var, "DEEPSEEK_API_KEY"))

    def connect(self):
        """No-op: no client is built for this provider; returns None."""
        return None


class PostgreSQL(BaseConnection):
    """
    Connection to a PostgreSQL database through ``psycopg``.

    Attributes:
        host (str): From 'POSTGRESQL_HOST' ("localhost" is the get_env_var default).
        port (int): From 'POSTGRESQL_PORT' (5432 is the get_env_var default).
        database (str): From 'POSTGRESQL_DATABASE' ("db" is the get_env_var default).
        user (str): From 'POSTGRESQL_USER' ("postgres" is the get_env_var default).
        password (str): From 'POSTGRESQL_PASSWORD' ("password" is the get_env_var default).
    """

    host: str = Field(default_factory=partial(get_env_var, "POSTGRESQL_HOST", "localhost"))
    port: int = Field(default_factory=partial(get_env_var, "POSTGRESQL_PORT", 5432))
    database: str = Field(default_factory=partial(get_env_var, "POSTGRESQL_DATABASE", "db"))
    user: str = Field(default_factory=partial(get_env_var, "POSTGRESQL_USER", "postgres"))
    password: str = Field(default_factory=partial(get_env_var, "POSTGRESQL_PASSWORD", "password"))

    def connect(self):
        """
        Open a psycopg connection in autocommit mode that returns rows as dicts.

        Returns:
            psycopg.Connection: The open connection.

        Raises:
            ConnectionError: If psycopg cannot be imported or the connection fails. The original
                exception is chained as ``__cause__``.
        """
        try:
            import psycopg
            from psycopg.rows import dict_row

            # autocommit is passed to connect() so the connection is never handed back
            # in a half-configured state.
            conn = psycopg.connect(
                host=self.host,
                port=self.port,
                dbname=self.database,
                user=self.user,
                password=self.password,
                row_factory=dict_row,
                autocommit=True,
            )
            logger.debug(
                f"Connected to PostgreSQL with host={self.host}, "
                f"port={str(self.port)}, user={self.user}, "
                f"database={self.database}."
            )
            return conn
        except Exception as e:
            raise ConnectionError(f"Failed to connect to PostgreSQL: {str(e)}") from e

    @property
    def conn_params(self) -> str:
        """
        Returns the connection URI.

        Returns:
            str: ``postgresql://user:password@host:port/database``. The user, password and database
            are percent-encoded, so reserved characters such as ``@``, ``:`` or ``/`` in a password
            do not corrupt the URI.
        """
        from urllib.parse import quote

        user = quote(str(self.user), safe="")
        password = quote(str(self.password), safe="")
        database = quote(str(self.database), safe="")
        return f"postgresql://{user}:{password}@{self.host}:{self.port}/{database}"


class Exa(Http):
    """
    Connection to the Exa AI Search API.

    Attributes:
        url (str): Exa API base URL, pinned to "https://api.exa.ai".
        method (Literal[HTTPMethod.POST]): HTTP method for requests; always POST.
        api_key (str): API key, obtained by default through ``get_env_var("EXA_API_KEY")``.
    """

    url: Literal["https://api.exa.ai"] = Field(default="https://api.exa.ai")
    method: Literal[HTTPMethod.POST] = HTTPMethod.POST
    api_key: str = Field(default_factory=partial(get_env_var, "EXA_API_KEY"))

    @model_validator(mode="after")
    def setup_headers(self):
        """After validation, add the ``x-api-key`` and JSON ``Content-Type`` headers.

        The headers are left untouched when ``api_key`` is empty.
        """
        if not self.api_key:
            return self
        self.headers.update({"x-api-key": self.api_key, "Content-Type": "application/json"})
        return self


class Ollama(BaseConnection):
    """Connection to an Ollama server.

    Attributes:
        url (str): Base URL of the Ollama API; "http://localhost:11434" unless overridden.
    """

    url: str = Field(default="http://localhost:11434")

    def connect(self):
        """Hand back the ``requests`` module for issuing HTTP calls to the API.

        Returns:
            requests: The imported ``requests`` module itself (no session is created).
        """
        import requests as http_module

        return http_module

    @property
    def conn_params(self) -> dict:
        """Connection parameters for the Ollama API.

        Returns:
            dict: A single entry mapping ``'api_base'`` to ``url``.
        """
        params = {"api_base": self.url}
        return params


class Jina(Http):
    """
    HTTP connection for the Jina Scrape API.

    ``api_key`` is obtained by default through ``get_env_var("JINA_API_KEY")`` and
    requests are always sent with GET.
    """

    api_key: str = Field(default_factory=partial(get_env_var, "JINA_API_KEY"))
    method: Literal[HTTPMethod.GET] = HTTPMethod.GET

    @model_validator(mode="after")
    def setup_headers(self):
        """After validation, add a Bearer ``Authorization`` header when a key is set."""
        if not self.api_key:
            return self
        bearer = f"Bearer {self.api_key}"
        self.headers.update({"Authorization": bearer})
        return self


class MySQL(BaseConnection):
    """Connection settings for a MySQL database accessed through ``mysql.connector``.

    Each field defaults to a value obtained through ``get_env_var`` from the
    matching ``MYSQL_*`` environment variable.
    """

    host: str = Field(default_factory=partial(get_env_var, "MYSQL_HOST", "localhost"))
    port: int = Field(default_factory=partial(get_env_var, "MYSQL_PORT", 3306))
    database: str = Field(default_factory=partial(get_env_var, "MYSQL_DATABASE", "db"))
    user: str = Field(default_factory=partial(get_env_var, "MYSQL_USER", "mysql"))
    password: str = Field(default_factory=partial(get_env_var, "MYSQL_PASSWORD", "password"))

    def connect(self):
        """Open a MySQL connection with autocommit enabled.

        Returns:
            The connection returned by ``mysql.connector.connect``.

        Raises:
            ConnectionError: If ``mysql.connector`` reports an error while connecting.
        """
        # Imported outside the try block because the except clause below needs it.
        import mysql.connector

        try:
            connection = mysql.connector.connect(
                host=self.host,
                port=self.port,
                database=self.database,
                user=self.user,
                passwd=self.password,
            )
            connection.autocommit = True
            logger.debug(f"Connected to MySQL with host={self.host}, user={self.user}, database={self.database}.")
            return connection
        except mysql.connector.Error as err:
            raise ConnectionError(f"Failed to connect to MySQL: {str(err)}")

    @property
    def cursor_params(self) -> dict:
        """Keyword arguments for creating a cursor that yields rows as dictionaries."""
        return dict(dictionary=True)


class Snowflake(BaseConnection):
    """Connection settings for Snowflake accessed through ``snowflake.connector``.

    Each field defaults to a value obtained through ``get_env_var`` from the
    matching ``SNOWFLAKE_*`` environment variable. The schema is stored in
    ``snowflake_schema`` but accepted under the alias ``schema``.
    """

    user: str = Field(default_factory=partial(get_env_var, "SNOWFLAKE_USER", "snowflake"))
    password: str = Field(default_factory=partial(get_env_var, "SNOWFLAKE_PASSWORD", "password"))
    account: str = Field(default_factory=partial(get_env_var, "SNOWFLAKE_ACCOUNT", "account"))
    warehouse: str = Field(default_factory=partial(get_env_var, "SNOWFLAKE_WAREHOUSE", "warehouse"))
    database: str = Field(default_factory=partial(get_env_var, "SNOWFLAKE_DATABASE", "db"))
    snowflake_schema: str = Field(default_factory=partial(get_env_var, "SNOWFLAKE_SCHEMA", "schema"), alias="schema")

    def connect(self):
        """Open a Snowflake connection.

        Returns:
            The connection returned by ``snowflake.connector.connect``.

        Raises:
            ConnectionError: If the connector cannot be imported or connecting fails.
        """
        try:
            import snowflake.connector

            connect_kwargs = dict(
                user=self.user,
                password=self.password,
                account=self.account,
                warehouse=self.warehouse,
                database=self.database,
                schema=self.snowflake_schema,
            )
            connection = snowflake.connector.connect(**connect_kwargs)
            logger.debug(
                f"Connected to Snowflake using account={self.account}, "
                f"warehouse={self.warehouse!s}, user={self.user}, "
                f"database={self.database}, schema={self.snowflake_schema}."
            )
            return connection
        except Exception as exc:
            raise ConnectionError(f"Failed to connect to Snowflake: {str(exc)}")

    @property
    def cursor_params(self) -> dict:
        """Keyword arguments for creating a cursor that returns rows as dicts (``DictCursor``)."""
        from snowflake.connector import DictCursor

        params = {"cursor_class": DictCursor}
        return params


class AWSRedshift(BaseConnection):
    """Connection settings for an Amazon Redshift cluster, opened through psycopg.

    Each field defaults to a value obtained through ``get_env_var`` from the
    matching ``AWS_REDSHIFT_*`` environment variable. ``host`` is given no
    fallback value in its field definition.

    Attributes:
        host (str): Cluster endpoint host name.
        port (int): Cluster port.
        database (str): Database name.
        user (str): Database user.
        password (str): Password for ``user``.
    """

    host: str = Field(default_factory=partial(get_env_var, "AWS_REDSHIFT_HOST"))
    port: int = Field(default_factory=partial(get_env_var, "AWS_REDSHIFT_PORT", 5439))
    database: str = Field(default_factory=partial(get_env_var, "AWS_REDSHIFT_DATABASE", "db"))
    user: str = Field(default_factory=partial(get_env_var, "AWS_REDSHIFT_USER", "awsuser"))
    password: str = Field(default_factory=partial(get_env_var, "AWS_REDSHIFT_PASSWORD", "password"))

    def connect(self):
        """Open a psycopg connection to Redshift in autocommit mode with dict rows.

        Returns:
            The connection returned by ``psycopg.connect`` (UTF-8 client encoding,
            ``row_factory=psycopg.rows.dict_row``, ``autocommit`` enabled).

        Raises:
            ConnectionError: If psycopg cannot be imported or the connection fails;
                the original exception is chained as ``__cause__``.
        """
        try:
            import psycopg

            conn = psycopg.connect(
                host=self.host,
                port=self.port,
                dbname=self.database,
                user=self.user,
                password=self.password,
                client_encoding="utf-8",
                row_factory=psycopg.rows.dict_row,
            )
            conn.autocommit = True
            logger.debug(
                f"Connected to Amazon Redshift with host={self.host}, "
                f"port={str(self.port)}, user={self.user}, "
                f"database={self.database}."
            )
            return conn
        except Exception as e:
            raise ConnectionError(f"Failed to connect to Amazon Redshift: {str(e)}") from e


class Elasticsearch(BaseConnection):
    """
    Represents a connection to the Elasticsearch service.

    Each field defaults to a value obtained through ``get_env_var`` from the
    matching ``ELASTICSEARCH_*`` environment variable.

    Attributes:
        url (str): URL of the Elasticsearch service; used only when ``cloud_id`` is not set.
        api_key_id (str | None): Optional API key ID. When set together with ``api_key``,
            the pair ``(api_key_id, api_key)`` is sent to the client.
        api_key (str | None): API key for authentication; takes precedence over basic auth.
        username (str | None): Username for basic authentication.
        password (str | None): Password for basic authentication.
        cloud_id (str | None): Cloud ID for Elastic Cloud deployment.
        ca_path (str | None): Path to CA certificate for SSL verification.
        verify_certs (bool): Whether to verify SSL certificates (applied only when ``use_ssl``).
        use_ssl (bool): Whether the SSL options (``ca_certs``, ``verify_certs``) are passed to the client.
    """

    url: str = Field(default_factory=partial(get_env_var, "ELASTICSEARCH_URL", None))
    api_key_id: str | None = Field(default_factory=partial(get_env_var, "ELASTICSEARCH_API_KEY_ID", None))
    api_key: str | None = Field(default_factory=partial(get_env_var, "ELASTICSEARCH_API_KEY", None))
    username: str | None = Field(default_factory=partial(get_env_var, "ELASTICSEARCH_USERNAME", None))
    password: str | None = Field(default_factory=partial(get_env_var, "ELASTICSEARCH_PASSWORD", None))
    cloud_id: str | None = Field(default_factory=partial(get_env_var, "ELASTICSEARCH_CLOUD_ID", None))
    ca_path: str | None = Field(default_factory=partial(get_env_var, "ELASTICSEARCH_CA_PATH", None))
    verify_certs: bool = Field(default_factory=partial(get_env_var, "ELASTICSEARCH_VERIFY_CERTS", False))
    use_ssl: bool = Field(default_factory=partial(get_env_var, "ELASTICSEARCH_USE_SSL", True))

    def connect(self):
        """
        Connects to the Elasticsearch service.

        Returns:
            elasticsearch.Elasticsearch: An instance of the Elasticsearch client.

        Raises:
            ConnectionError: If the client cannot be created, or if the cluster does not
                answer ``ping()``. The latter includes cases where the follow-up
                ``info()`` call itself fails.
            ValueError: If neither API key nor basic auth credentials are provided
                for a non-cloud deployment.
        """

        from elasticsearch import Elasticsearch
        from elasticsearch.exceptions import AuthenticationException

        # Build connection params
        conn_params = {}

        # Handle authentication: API key (optionally with its ID) wins over basic auth.
        if self.api_key is not None:
            if self.api_key_id is not None:
                conn_params["api_key"] = (self.api_key_id, self.api_key)
            else:
                conn_params["api_key"] = self.api_key
        elif self.username is not None and self.password is not None:
            conn_params["basic_auth"] = (self.username, self.password)
        elif self.cloud_id is None:  # Only require auth for non-cloud deployments
            raise ValueError("Either API key or username/password must be provided")

        # Handle SSL/TLS
        if self.use_ssl:
            if self.ca_path is not None:
                conn_params["ca_certs"] = self.ca_path
            conn_params["verify_certs"] = self.verify_certs

        # Handle cloud deployment: cloud_id replaces the explicit host list.
        if self.cloud_id is not None:
            conn_params["cloud_id"] = self.cloud_id
        else:
            conn_params["hosts"] = [self.url]

        # Create client
        try:
            es_client = Elasticsearch(**conn_params)
        except Exception as e:
            raise ConnectionError(f"Failed to connect to Elasticsearch: {str(e)}") from e

        if not es_client.ping():
            # ping() only yields a pass/fail result, so info() is queried to obtain
            # a reason. Any failure of info() is folded into the message so that the
            # caller always receives a ConnectionError rather than a client-library
            # exception type.
            try:
                info = es_client.info()
            except AuthenticationException as e:
                info = f"Authentication error: {e}"
            except Exception as e:
                info = f"{type(e).__name__}: {e}"
            raise ConnectionError(f"Failed to connect to Elasticsearch. {info}")

        logger.debug(f"Connected to Elasticsearch at {self.cloud_id or self.url}")
        return es_client


class AWSOpenSearch(AWS):
    """
    Connection to an AWS OpenSearch Service domain, signed with SigV4 credentials
    taken from the inherited AWS session logic.

    Attributes:
        host (str): The OpenSearch domain endpoint host.
        port (int): Port of the OpenSearch domain. Defaults to 443.
        region (str): AWS region where the OpenSearch domain is located (inherited).
        access_key_id (str): AWS access key ID for authentication (inherited).
        secret_access_key (str): AWS secret access key for authentication (inherited).
        profile (str): AWS profile name to use for authentication (inherited).
        service (str): Service name used for request signing. Defaults to 'es'.
        use_ssl (bool): Whether to use SSL for the connection. Defaults to True.
        verify_certs (bool): Whether to verify SSL certificates. Defaults to True.
        ca_certs (str): Path to CA certificates for SSL verification.
    """

    host: str = Field(default_factory=partial(get_env_var, "AWS_OPENSEARCH_HOST"))
    port: int = Field(default_factory=partial(get_env_var, "AWS_OPENSEARCH_PORT", 443))

    service: str = Field(default_factory=partial(get_env_var, "AWS_OPENSEARCH_SERVICE", "es"))
    use_ssl: bool = Field(default_factory=partial(get_env_var, "AWS_OPENSEARCH_USE_SSL", True))
    verify_certs: bool = Field(default_factory=partial(get_env_var, "AWS_OPENSEARCH_VERIFY_CERTS", True))
    ca_certs: str | None = Field(default_factory=partial(get_env_var, "AWS_OPENSEARCH_CA_CERTS"))

    def connect(self):
        """
        Build an OpenSearch client authenticated with ``AWS4Auth``.

        Returns:
            opensearchpy.OpenSearch: A client that answered ``ping()``.

        Raises:
            ValueError: If the boto3 session yields no credentials.
            ConnectionError: If the client cannot be created or the ping fails.
        """
        from opensearchpy import OpenSearch, RequestsHttpConnection
        from requests_aws4auth import AWS4Auth

        credentials = self.get_boto3_session().get_credentials()
        if not credentials:
            raise ValueError("Missing AWS credentials: provide profile or access_key_id/secret_access_key")

        signer = AWS4Auth(
            credentials.access_key,
            credentials.secret_key,
            self.region,
            self.service,
            session_token=credentials.token,
        )

        client_kwargs = dict(
            hosts=[{"host": self.host, "port": self.port}],
            http_auth=signer,
            use_ssl=self.use_ssl,
            verify_certs=self.verify_certs,
            connection_class=RequestsHttpConnection,
        )
        # Only forward a CA bundle when one is configured (non-empty).
        if self.ca_certs:
            client_kwargs["ca_certs"] = self.ca_certs

        try:
            client = OpenSearch(**client_kwargs)
        except Exception as exc:
            raise ConnectionError(f"Failed to connect to AWS OpenSearch: {exc}")

        if client.ping():
            logger.debug(
                f"Connected to AWS OpenSearch at {self.host}:{self.port} (region={self.region}, profile={self.profile})"
            )
            return client

        raise ConnectionError(f"Failed to ping AWS OpenSearch cluster at {self.host}:{self.port}")

    @property
    def conn_params(self) -> dict:
        """
        Inherited AWS parameters, updated in place with the OpenSearch-specific settings.
        """
        merged = super().conn_params
        merged.update(
            hosts=[{"host": self.host, "port": self.port}],
            service=self.service,
            use_ssl=self.use_ssl,
            verify_certs=self.verify_certs,
        )
        if self.ca_certs:
            merged["ca_certs"] = self.ca_certs
        return merged


class xAI(BaseApiKeyConnection):
    """Connection that only carries an xAI API key.

    ``api_key`` is filled by default through ``get_env_var("XAI_API_KEY")``.
    """

    api_key: str = Field(default_factory=partial(get_env_var, "XAI_API_KEY"))

    def connect(self):
        """No client is built for this connection; always returns ``None``."""
        return None


class FireworksAI(BaseApiKeyConnection):
    """Connection that only carries a Fireworks AI API key.

    ``api_key`` is filled by default through ``get_env_var("FIREWORKS_AI_API_KEY")``.
    """

    api_key: str = Field(default_factory=partial(get_env_var, "FIREWORKS_AI_API_KEY"))

    def connect(self):
        """No client is built for this connection; always returns ``None``."""
        return None


class NvidiaNIM(BaseApiKeyConnection):
    """Connection to an NVIDIA NIM endpoint.

    ``url`` and ``api_key`` default to ``get_env_var("NVIDIA_NIM_URL")`` and
    ``get_env_var("NVIDIA_NIM_API_KEY")`` respectively.
    """

    url: str = Field(default_factory=partial(get_env_var, "NVIDIA_NIM_URL"))
    api_key: str = Field(default_factory=partial(get_env_var, "NVIDIA_NIM_API_KEY"))

    def connect(self):
        """No client is built for this connection; always returns ``None``."""
        return None

    @property
    def conn_params(self) -> dict:
        """
        Return the endpoint URL under ``api_base`` and the key under ``api_key``.
        """
        return dict(api_base=self.url, api_key=self.api_key)


class Databricks(BaseApiKeyConnection):
    """Connection to a Databricks API endpoint.

    ``url`` and ``api_key`` default to ``get_env_var("DATABRICKS_API_BASE")`` and
    ``get_env_var("DATABRICKS_API_KEY")`` respectively.
    """

    url: str = Field(default_factory=partial(get_env_var, "DATABRICKS_API_BASE"))
    api_key: str = Field(default_factory=partial(get_env_var, "DATABRICKS_API_KEY"))

    def connect(self):
        """No client is built for this connection; always returns ``None``."""
        return None

    @property
    def conn_params(self) -> dict:
        """
        Return the API base URL under ``api_base`` and the key under ``api_key``.
        """
        return dict(api_base=self.url, api_key=self.api_key)


class MCPSse(BaseConnection):
    """Connection to an MCP server over Server-Sent Events (SSE).

    Attributes:
        url (str): SSE endpoint URL (required).
        headers (dict[str, Any] | None): Extra request headers.
        timeout (float): Seconds allowed for establishing the connection (default 5.0).
        sse_read_timeout (float): Seconds allowed for reading SSE messages (default 300).
    """

    url: str = Field(..., description="The SSE endpoint URL to connect to.")
    headers: dict[str, Any] | None = Field(default=None, description="Optional headers to include in the SSE request.")
    timeout: float = Field(default=5.0, description="Timeout in seconds for establishing the initial connection.")
    sse_read_timeout: float = Field(default=5 * 60, description="Timeout for reading SSE messages (in seconds).")

    def connect(self):
        """
        Create the MCP SSE client for this endpoint.

        Returns:
            Async context manager for the SSE client.
        """
        from mcp.client.sse import sse_client

        client_kwargs = {
            "url": self.url,
            "headers": self.headers,
            "timeout": self.timeout,
            "sse_read_timeout": self.sse_read_timeout,
        }
        return sse_client(**client_kwargs)


class MCPStreamableHTTP(BaseConnection):
    """Connection to an MCP server over streamable HTTP.

    Timeouts are stored as seconds and converted to ``timedelta`` when the
    client is created.

    Attributes:
        url (str): Endpoint URL (required).
        headers (dict[str, Any] | None): Extra request headers.
        timeout (float): Seconds allowed for establishing the connection (default 30.0).
        sse_read_timeout (float): Seconds allowed for reading messages (default 300).
    """

    url: str = Field(..., description="The endpoint URL to connect to.")
    headers: dict[str, Any] | None = Field(default=None, description="Optional headers to include in the request.")
    timeout: float = Field(default=30.0, description="Timeout in seconds for establishing the initial connection.")
    sse_read_timeout: float = Field(default=5 * 60, description="Timeout for reading messages (in seconds).")

    def connect(self):
        """
        Create the MCP streamable HTTP client for this endpoint.

        Returns:
            Async context manager for the streamable HTTP client.
        """
        from mcp.client.streamable_http import streamablehttp_client

        connect_timeout = timedelta(seconds=self.timeout)
        read_timeout = timedelta(seconds=self.sse_read_timeout)
        return streamablehttp_client(
            url=self.url,
            headers=self.headers,
            timeout=connect_timeout,
            sse_read_timeout=read_timeout,
        )


class MCPEncodingErrorHandler(str, Enum):
    """Codec error-handling strategies accepted for MCP STDIO communication.

    The string values are passed on as ``encoding_error_handler`` when the
    STDIO server parameters are built.
    """

    STRICT = "strict"
    IGNORE = "ignore"
    REPLACE = "replace"


class MCPStdio(BaseConnection):
    """Connection to an MCP server launched as a subprocess and spoken to over STDIO.

    Attributes:
        command (str): Executable that starts the server (required).
        args (list[str]): Command-line arguments for the executable.
        env (dict[str, str] | None): Environment variables for the process.
        cwd (str | Path | None): Working directory for the process.
        encoding (str): Text encoding for communication (default "utf-8").
        encoding_error_handler (MCPEncodingErrorHandler): Codec error strategy (default STRICT).
    """

    command: str = Field(..., description="The executable to run to start the server.")
    args: list[str] = Field(default_factory=list, description="Command-line arguments to pass to the executable.")
    env: dict[str, str] | None = Field(default=None, description="Environment variables for the process.")
    cwd: str | Path | None = Field(default=None, description="Working directory for the process.")
    encoding: str = Field(default="utf-8", description="Text encoding for communication.")
    encoding_error_handler: MCPEncodingErrorHandler = Field(
        default=MCPEncodingErrorHandler.STRICT, description="Strategy for handling encoding errors."
    )

    def connect(self):
        """
        Create the MCP STDIO client for the configured subprocess.

        Returns:
            Async context manager for the STDIO client.
        """
        from mcp import StdioServerParameters
        from mcp.client.stdio import stdio_client

        # The enum's plain string value is what the server parameters receive.
        server_params = StdioServerParameters(
            command=self.command,
            args=self.args,
            env=self.env,
            cwd=self.cwd,
            encoding=self.encoding,
            encoding_error_handler=self.encoding_error_handler.value,
        )
        return stdio_client(server_params)


class DatabricksSQL(BaseConnection):
    """Connection to a Databricks SQL warehouse through ``databricks.sql``.

    Fields default to ``get_env_var`` lookups of ``DATABRICKS_SERVER_HOSTNAME``,
    ``DATABRICKS_HTTP_PATH`` and ``DATABRICKS_TOKEN``.
    """

    server_hostname: str = Field(default_factory=partial(get_env_var, "DATABRICKS_SERVER_HOSTNAME"))
    http_path: str = Field(default_factory=partial(get_env_var, "DATABRICKS_HTTP_PATH"))
    access_token: str = Field(default_factory=partial(get_env_var, "DATABRICKS_TOKEN"))

    def connect(self):
        """Open a Databricks SQL connection.

        Returns:
            The connection returned by ``databricks.sql.connect``.

        Raises:
            ConnectionError: If the connector cannot be imported or connecting fails.
        """
        try:
            from databricks import sql

            connect_kwargs = dict(
                server_hostname=self.server_hostname,
                http_path=self.http_path,
                access_token=self.access_token,
            )
            connection = sql.connect(**connect_kwargs)
            logger.debug(
                f"Connected to DataBricks using server hostname={self.server_hostname}, "
                f"http path={self.http_path!s}"
            )
            return connection
        except Exception as exc:
            raise ConnectionError(f"Failed to connect to Databricks: {str(exc)}")
